class Solution:
    """Counts structurally unique binary search trees."""

    def numTrees(self, n: int) -> int:
        """Return the number of structurally unique BSTs storing keys 1..n.

        Choosing any key as the root splits the remaining keys into a left
        subtree of ``left`` nodes and a right subtree of ``size - 1 - left``
        nodes, and the two sides are independent. Summing the product of the
        two subtree counts over every possible split gives the count for
        ``size`` nodes (the Catalan numbers).

        Args:
            n: Number of nodes; must be non-negative.

        Returns:
            The count of distinct BST shapes with ``n`` nodes. ``n == 0``
            yields 1, counting the single empty tree.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError("n must be non-negative")

        # counts[k] holds the number of unique BSTs with exactly k nodes.
        counts = [0] * (n + 1)
        # The empty tree is the single shape with zero nodes; it also makes
        # the products below work when one side of the root is empty.
        counts[0] = 1
        for size in range(1, n + 1):
            counts[size] = sum(
                counts[left] * counts[size - 1 - left]
                for left in range(size)
            )

        return counts[n]
# Quick manual check: print the tree counts for n = 3 and n = 4.
for node_count in (3, 4):
    print(Solution().numTrees(node_count))